/*!
 * Copyright (c) Alibaba, Inc. and its affiliates.
 * @file    test_common.h
 */
#pragma once
#include <gtest/gtest.h>

#if ENABLE_FP16
#include <cuda_fp16.h>
#endif
#if ENABLE_BF16
#include <hie_bfloat16.hpp>
#endif

#include <algorithm>
#include <cmath>
#include <random>
#include <thread>

template <typename T>
float check_equal_vec(const std::vector<T>& a, const std::vector<T>& b,
                      bool absolute = false) {
  float max_eps = 0;
  if (a.size() != b.size()) {
    printf("vector compare with different size, a:%ld b:%ld\n", a.size(),
           b.size());
    throw std::invalid_argument("not equal sized vector");
  }

  int count = std::min(a.size(), b.size());
  for (int i = 0; i < count; ++i) {
    float eps = 0;

    if (absolute) {
      eps = std::fabs((float)(a[i] - b[i]));
    } else {
      eps = std::min(std::fabs((float)(b[i] - a[i])),
                     std::fabs((float)(a[i] - b[i]) / (float)(b[i])));
    }
    if (std::is_same<T, float>::value && std::isnan(eps)) {
      eps = std::numeric_limits<float>::infinity();
    }
    if (max_eps < eps) {
      max_eps = eps;
    }
  }
  return max_eps;
}

template <typename T>
float check_equal(const void* data_a, const void* data_b, size_t count,
                  bool absolute = false, int nworkers = 32) {
  const T* ptrA = (const T*)data_a;
  const T* ptrB = (const T*)data_b;
  auto check_func = [&ptrA, &ptrB, absolute](float* max_eps, size_t start,
                                             size_t end, int worker_id) {
    *max_eps = 0;
    for (size_t i = start; i < end; ++i) {
      float eps = 0;
      if (absolute) {
        eps = std::fabs(float(ptrA[i]) - float(ptrB[i]));
      } else {
        eps = std::min(
            std::fabs(float(ptrA[i]) - float(ptrB[i])),
            std::fabs((float(ptrA[i]) - float(ptrB[i])) / float(ptrB[i])));
#if 0
        printf("worker %d, ref: %f, out: %f, eps: %f\n", worker_id,
               float(ptrA[i]), float(ptrB[i]), eps);
#endif
      }
      if (!absolute && std::isnan(eps)) {
        eps = std::numeric_limits<float>::infinity();
      }
      if (*max_eps < eps) {
        *max_eps = eps;
      }
    }
    return;
  };

  std::vector<float> max_eps_vec(nworkers);
  std::vector<std::unique_ptr<std::thread>> thds(nworkers);
  size_t chunk_size = count / nworkers;
  for (int i = 0; i < nworkers - 1; ++i) {
    thds[i] =
        std::make_unique<std::thread>(check_func, max_eps_vec.data() + i,
                                      i * chunk_size, (i + 1) * chunk_size, i);
  }
  thds[nworkers - 1] = std::make_unique<std::thread>(
      check_func, max_eps_vec.data() + nworkers - 1,
      (nworkers - 1) * chunk_size, count, nworkers - 1);

  for (int i = 0; i < nworkers; ++i) {
    thds[i]->join();
  }

  float max_eps = *std::max_element(max_eps_vec.begin(), max_eps_vec.end());
  return max_eps;
}

namespace AS_UTEST {

/**
 * @brief Generate Random data
 *
 */
template <typename T>
void generate_random_data_impl(T* data, size_t size, T lower_range,
                               T upper_range, int seed, ...);

template <typename T, typename std::enable_if<
                          std::is_integral<T>::value &&
                          !std::is_same<bool, T>::value>::type* = nullptr>
void generate_random_data_impl(T* data, size_t size, T lower_range,
                               T upper_range, int seed, int) {
  std::default_random_engine generator;
  if (seed != 0) {
    generator.seed(seed);
  } else {
    std::random_device rd;
    generator.seed(rd());
  }
  std::uniform_int_distribution<T> uni(lower_range, upper_range);
  std::generate(data, data + size, [&]() { return uni(generator); });
}

template <typename T, typename std::enable_if<
                          std::is_same<bool, T>::value>::type* = nullptr>
void generate_random_data_impl(T* data, size_t size, T lower_range,
                               T upper_range, int seed, int) {
  bool* data_bool = static_cast<bool*>(data);
  std::default_random_engine generator;
  if (seed != 0) {
    generator.seed(seed);
  } else {
    std::random_device rd;
    generator.seed(rd());
  }
  std::uniform_int_distribution<int> uni(0, 1);
  std::generate(data_bool, data_bool + size,
                [&]() { return uni(generator) != 0 ? true : false; });
}

template <typename T, typename std::enable_if<
                          std::is_same<T, float>::value ||
                          std::is_same<T, double>::value>::type* = nullptr>
void generate_random_data_impl(T* data, size_t size, T lower_range,
                               T upper_range, int seed, int) {
  std::default_random_engine generator;
  if (seed != 0) {
    generator.seed(seed);
  } else {
    std::random_device rd;
    generator.seed(rd());
  }
  std::uniform_real_distribution<T> uni(lower_range, upper_range);
  std::generate(data, data + size, [&]() { return uni(generator); });
}

struct __invalid_type {};

template <typename T,
          typename std::enable_if<
#if ENABLE_FP16
              std::is_same<T, half>::value ||
#endif
#if ENABLE_BF16
              std::is_same<T, hie::bfloat16>::value ||
#endif
              std::is_same<T, __invalid_type>::value>::type* = nullptr>
void generate_random_data_impl(T* data, size_t size, T lower_range,
                               T upper_range, int seed, int) {
  std::vector<float> tmp(size);
  generate_random_data_impl<float>(
      tmp.data(), size, static_cast<float>(lower_range),
      static_cast<float>(upper_range), seed, int{});
  for (size_t i = 0; i < size; i++) {
    data[i] = static_cast<T>(tmp[i]);
  }
}

template <typename T>
void generate_random_data(T* data, size_t size, T lower_range, T upper_range,
                          int seed = 0) {
  generate_random_data_impl(data, size, lower_range, upper_range, seed, int{});
}

template <typename T>
void generate_random_data(std::vector<T>& data, size_t size, T lower_range,
                          T upper_range, int seed = 0) {
  data.clear();
  data.resize(size);
  generate_random_data(data.data(), size, lower_range, upper_range, seed);
}

template <typename T>
void generate_random_data(T* data, size_t size, T range = T(1)) {
  generate_random_data(data, size, T(-range), range, 0);
}

template <typename T>
void generate_random_data(std::vector<T>& data, size_t size, T range = T(1)) {
  data.clear();
  data.resize(size);
  generate_random_data(data.data(), size, range);
}

}  // namespace AS_UTEST
